#!/usr/bin/python

import code,pickle,sys,os,re,traceback,sqlite3,warnings
import matplotlib
# Backend selection: GTK when the DISPLAY environment variable is set,
# AGG otherwise.
matplotlib.use("GTK" if "DISPLAY" in os.environ else "AGG")
from optparse import OptionParser
from pylab import *
import ocrolib
from ocrolib import docproc,Record,utils,fstutils
from scipy import stats

# Command line interface.  The usage text is assembled line by line so
# that the trailing blanks inside it stay visible.
parser = OptionParser(
    "\n"
    "usage: %prog [options] .../.../010001.png ...\n"
    "\n"
    "Extract character images from text line image files using the cseg files\n"
    "left by the recognizer; character images are labeled by their corresponding \n"
    "characters in the .txt files.  \n"
    "\n"
    "You must run ocropus-lattices and ocropus-align first to obtain the cseg \n"
    "and txt files.  You can also manually create the cseg and txt files.\n"
)

# Input / output selection.
parser.add_option("-G", "--gtsuffix", dest="gtsuffix", default=None,
                  help="ground truth suffix")
parser.add_option("-g", "--nogeo", dest="nogeo", action="store_true",
                  help="don't extract geometry")
parser.add_option("-o", "--output", dest="output", default="chars.db",
                  help="output file")
parser.add_option("-u", "--unmerged", dest="unmerged", default=None,
                  help="unmerged output file")

# Which characters end up in the database.
parser.add_option("-n", "--nomissegmented", dest="nomissegmented",
                  action="store_true",
                  help="output no missegmented characters")
parser.add_option("-r", "--raw", dest="raw", action="store_true",
                  help="output unlabeled characters")
parser.add_option("-a", "--maxage", dest="maxage", type="int",
                  default=10000000,
                  help="output missegmented")

# Diagnostics and error handling.
parser.add_option("-D", "--display", dest="display", action="store_true",
                  help="display chars")
parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
                  help="verbose output")
parser.add_option("-N", "--nosource", dest="nosource", action="store_true",
                  help="do not record source info")
parser.add_option("-c", "--cont", dest="cont", action="store_true",
                  help="continue even if errors are found")

# Alignment cost thresholds.
parser.add_option("-E", "--extracost", dest="extracost", type=float,
                  default=0.0,
                  help="cost for misaligned characters")
parser.add_option("-P", "--perc", dest="perc", type=float, default=90.0,
                  help="cost percentile")
parser.add_option("-M", "--maxperc", dest="maxperc", type=float, default=2.0,
                  help="maximum cost at percentile")
parser.add_option("-A", "--maxavg", dest="maxavg", type=float, default=3.0,
                  help="maximum average cost")

options, args = parser.parse_args()
args = ocrolib.expand_args(args)

# Without any input file there is nothing to do: print the usage text
# and exit with status 0.
if len(args)<1:
    parser.print_help()
    sys.exit(0)

# Turn on pylab's interactive mode; note that this runs whether or not
# --display was given.
ion()
show()

# Refuse to reuse an existing output file; the user has to remove it.
if os.path.exists(options.output):
    print options.output,"exists; please remove"
    sys.exit(1)

# Name of the database table that receives the character records.
table = "chars"

# Open the output database and configure the connection.
db = sqlite3.connect(options.output,timeout=600.0)
db.row_factory = utils.DbRow
db.text_factory = sqlite3.OptimizedUnicode
db.execute("pragma synchronous=0")
# NOTE(review): utils.charcolumns presumably sets up the table columns
# used by the inserts further down -- confirm against ocrolib.utils.
utils.charcolumns(db,table)
db.commit()

class BadGroundTruth(Exception):
    """Error type for bad ground truth (by its name).

    Nothing in the visible part of this script raises or catches it.
    """

def read_costs(fname):
    """Read the cost entries stored next to a text line image.

    The data is opened with ``ocrolib.fopen(fname,"costs")``.  Each
    non-blank line must start with two whitespace separated fields: an
    integer position and a floating point cost.  The cost is stored at
    array index ``position-1``.

    Returns an array long enough to hold the largest index seen; indexes
    that are never mentioned keep a cost of zero.  Blank lines are
    skipped, and an empty array is returned when no entry is present.
    """
    entries = []
    with ocrolib.fopen(fname,"costs") as stream:
        for line in stream.readlines():
            fields = line.split()
            # A blank line (e.g. at the end of the file) has no fields.
            if not fields: continue
            entries.append((int(fields[0])-1,float(fields[1])))
    if not entries:
        return zeros(0)
    costs = zeros(max(index for index,_ in entries)+1)
    # Later entries for the same index overwrite earlier ones.
    for index,cost in entries:
        costs[index] = cost
    return costs

def cseg_chars_all(file,suffix=options.gtsuffix):
    image = ocrolib.read_image_gray(file)
    image = 255-image
    cseg_file = ocrolib.ffind(file,"cseg")
    rseg_file = ocrolib.ffind(file,"rseg")
    cseg = ocrolib.read_line_segmentation(cseg_file)
    rseg = ocrolib.read_line_segmentation(rseg_file)
    if amax(rseg)<3:
        print "%s: not enough segments"%file
        return
    with ocrolib.fopen(file,"txt") as stream: gt = stream.read()
    gt = gt.decode("utf-8")
    raw = gt
    gt = re.sub('\n','',gt)
    gt = fstutils.explode_transcription(gt)
    costs = read_costs(file)
    while len(gt)>0 and gt[-1]==" ":
        gt = gt[:-1]
        costs = costs[:-1]

    perc = stats.scoreatpercentile(costs[1:],options.perc)
    avg = mean(costs[1:])
    skip = perc>options.maxperc or avg>options.maxavg
    if skip: 
        print "%s: cost too high"%file
        return

    if len(gt)>len(costs):
        print "%s: gt and costs differ: %d != %d"%(file,len(gt),len(costs))
        print "###",file,amax(cseg)
        print "###",len(gt),gt
        print "###",len(costs),costs
        return
    if not options.nogeo:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            geo = docproc.seg_geometry(rseg)
    else:
        geo = None
    grouper = ocrolib.StandardGrouper()
    grouper.setSegmentationAndGt(rseg,cseg,gt)
    for i in range(grouper.length()):
        index = grouper.getGtIndex(i)
        cls = grouper.getGtClass(i)
        cost= costs[index-1] if index>0 else options.extracost
        if cls==' ': continue
        raw,mask = grouper.extractWithMask(image,i,1)
        yield Record(raw=raw,mask=mask,cls=cls,index=grouper.getGtIndex(i),
                     bbox=grouper.boundingBox(i),
                     bbox_m=grouper.bboxMath(i),lgeo=geo,cost=cost)

ntried = 0
nfiles = 0
total = 0

for file in args:
    segments = []
    text = ""
    ntried += 1
    try:
        if options.nomissegmented:
            items = list(cseg_chars(file))
        else:
            items = list(cseg_chars_all(file))
        nfiles += 1
    except IOError,e:
        print "ERROR",e
        print "# cseg for",file,"not found (got %d of %d files)"%(nfiles,ntried)
        continue
    except Exception,e:
        if options.cont:
            traceback.print_exc()
            print "#",file,"failed",e
            continue
        else:
            traceback.print_exc()
            raise e
    if options.verbose:
        print file,len(segments),len(text)
    index = 0
    for x in items:
        raw = x.raw
        mask = x.mask
        cls = x.cls
        cost = x.cost
        # print index,cls
        index += 1
        segments.append(raw)
        if options.display:
            clf(); gray(); imshow(raw); draw()
        text += cls
        if cls is None:
            # no ground truth
            if not options.raw: continue
            cls = "_"
        elif cls<=0 or cls=="":
            # missegmented
            if not options.missegmented: continue
            cls = "~"
        if raw.shape[0]>255 or raw.shape[1]>255: continue
        raw = raw/float(amax(raw))
        key = re.sub(r'^.*/(\d\d\d\d/)','\\1',file)
        key = re.sub(r'\.png$','',key)
        if options.verbose:
            print key,cls,raw.shape
        if options.nosource:
            utils.dbinsert(db,table,
                           image=utils.image2blob(raw),
                           cost=cost,
                           cls=cls,
                           count=1,
                           lgeo="%g %g %g"%x.lgeo)
        else:
            x0,y0,x1,y1 = x.bbox_m
            rel = docproc.rel_char_geom((y0,y1,x0,x1),x.lgeo)
            ry,rw,rh = rel
            assert rw>0 and rh>0
            utils.dbinsert(db,table,
                           image=utils.image2blob(raw),
                           cost=cost,
                           cls=cls,
                           count=1,
                           file=file,
                           lgeo="%g %g %g"%x.lgeo,
                           rel="%g %g %g"%rel,
                           bbox="%d %d %d %d"%x.bbox_m)
        total+=1
        if total%10000==0:
            print total,"chars"
            db.commit()


db.commit()
print total,"chars"
